import numpy as np
from sklearn.metrics import pairwise_distances_argmin_min as dist

import math
import random

def cost(data, centers):
    """Return the clustering cost of ``centers`` on ``data``.

    The cost is the (unweighted) sum over all points of the squared distance
    to the nearest center. Distances come from scikit-learn's
    ``pairwise_distances_argmin_min`` (imported as ``dist``), called with its
    default metric.

    Args:
        data: Array of points, one per row.
        centers: Array of candidate centers, one per row.

    Returns:
        The summed squared nearest-center distance as a NumPy scalar.
    """
    # dist returns (index of nearest center, distance to it); only the
    # distances are needed here.
    nearest_distances = dist(data, centers)[1]
    return np.sum(np.square(nearest_distances))

def net_construction(data, smallest_net_distance = 0.1):
    """Build a multi-scale "net" of candidate centers around the data points.

    At every scale each data point is shifted by all combinations of
    ``-d``, ``0`` and ``+d`` per coordinate, where ``d`` starts at
    ``smallest_net_distance`` and doubles after each scale. Shifted points
    outside the axis-aligned bounding box of ``data`` are dropped, and the
    survivors are merged (without duplicates) into the net, which initially
    holds the data points themselves.

    The number of scales is ``M = int(log2(delta / smallest_net_distance))``
    with ``delta`` the largest per-coordinate extent of the data. When that
    value is negative, or when ``delta`` is 0 (all points coincide, so the
    logarithm is undefined), ``M`` falls back to 2.

    Note: each scale materialises ``n_points * 3 ** dimension`` candidates,
    so memory grows exponentially with the dimension.

    Args:
        data: 2-D NumPy array of shape ``(n_points, dimension)``.
        smallest_net_distance: Offset used at the finest scale; must be
            positive.

    Returns:
        A tuple ``(final_centers, M)``. ``final_centers`` is the array of
        unique candidate centers (row-sorted by ``np.unique``); if ``M`` is 0
        no scale is processed and ``data`` itself is returned unchanged.

    Raises:
        ValueError: If ``smallest_net_distance`` is not positive.
    """
    # Written as "not >" so that NaN is rejected as well.
    if not (smallest_net_distance > 0):
        raise ValueError(
            f"smallest_net_distance must be positive, got {smallest_net_distance!r}"
        )

    dimension = data.shape[1]

    # Bounding box of the data and its largest side length.
    min_values = np.min(data, axis=0)
    max_values = np.max(data, axis=0)
    delta = np.max(max_values - min_values)

    if delta == 0:
        # log2(0) is undefined (math.log would raise); treat it like the
        # "negative M" case below.
        M = 2
    else:
        M = int(math.log(delta / smallest_net_distance, 2))
        if M < 0:
            M = 2

    # All {-1, 0, 1}^dimension direction vectors; built once and scaled by
    # the current net distance inside the loop.
    unit_offsets = np.array(np.meshgrid(*[[-1, 0, 1]] * dimension)).T.reshape(-1, dimension)

    net_distance = smallest_net_distance
    final_centers = data

    for _ in range(M):
        offsets = unit_offsets * net_distance
        new_centers = (data[:, None, :] + offsets).reshape(-1, dimension)

        # Keep only candidates inside the bounding box of the data.
        mask = np.all((new_centers >= min_values) & (new_centers <= max_values), axis=1)
        final_centers = np.unique(
            np.concatenate((final_centers, new_centers[mask])), axis=0
        )

        net_distance *= 2

    return final_centers, M

def center_fixed_local_search(data, net, k, s=1, n_iter=10, random_seed = 0):
    """Randomized single-swap local search for ``k`` centers, keeping ``s`` fixed.

    ``k`` distinct data points are drawn as initial centers. Each round then
    replaces one randomly chosen center with index in ``[s, k - 1]`` by a
    randomly chosen point of ``net`` and keeps the swap only if it strictly
    lowers ``cost``. Centers with index below ``s`` are never replaced.

    The search stops after ``k * n_iter`` improving swaps or after
    ``k * n_iter * n_iter`` attempted swaps, whichever comes first. If
    ``s >= k`` there is no center that may be swapped and the initial
    centers are returned as they are.

    All random choices are drawn from Python's ``random`` module, which is
    seeded (globally) with ``random_seed``. Calls with the same seed and
    inputs therefore repeat the same search; pass different seeds to obtain
    independent restarts.

    Args:
        data: 2-D NumPy array of points, one per row.
        net: Indexable collection of candidate centers (``len`` and integer
            indexing are used).
        k: Number of centers; must not exceed the number of data points.
        s: Number of leading centers that are kept fixed.
        n_iter: Controls the round budgets described above.
        random_seed: Seed passed to ``random.seed``.

    Returns:
        A tuple ``(cost, centers, assignment)`` where ``assignment`` holds,
        for each data point, the index of its nearest center.

    Raises:
        ValueError: If ``k`` is larger than the number of data points.
    """
    random.seed(random_seed)

    # Draw the initial centers from the seeded `random` module. NumPy's
    # global generator is independent of `random.seed`, so using it here
    # would make the result irreproducible.
    indices = random.sample(range(data.shape[0]), k)
    centers = data[indices]
    # NOTE(review): centers keeps data's dtype; with integer data, float net
    # points would be truncated on assignment below — confirm data is float.

    current_cost = cost(data, centers)

    goal_rounds = k * n_iter
    max_rounds = goal_rounds * n_iter
    improved_rounds = 0
    attempt_rounds = 0

    # random.randint(s, k - 1) needs a non-empty range.
    can_swap = s < len(centers)

    while can_swap and improved_rounds < goal_rounds and attempt_rounds < max_rounds:
        attempt_rounds += 1

        center_idx = random.randint(s, len(centers) - 1)
        net_index = random.randint(0, len(net) - 1)

        new_centers = np.copy(centers)
        new_centers[center_idx] = net[net_index]

        new_cost = cost(data, new_centers)
        if new_cost < current_cost:
            # Accept the swap and remember its cost so the cost of the
            # current centers is not recomputed every round.
            centers = new_centers
            current_cost = new_cost
            improved_rounds += 1

    assignment, _ = dist(data, centers)

    return current_cost, centers, assignment

def k_means_algorithm_uniform(data, k, s=1, attempt = 100, n_iter=10, net_distance = 0.1, random_seed = 0):
    """Run the center-fixed local search with restarts and keep the best result.

    A net of candidate centers is built once with ``net_construction`` and
    ``center_fixed_local_search`` is then run ``attempt + 1`` times on it
    (one initial run plus ``attempt`` restarts). The run with the lowest
    cost wins; ties keep the earlier run.

    The initial run uses ``random_seed`` itself. When ``random_seed`` is an
    ``int``, restart number ``i`` (1-based) uses ``random_seed + i`` so that
    restarts do not replay the same seeded swap sequence. Any other seed
    value (e.g. ``None``) is passed through unchanged.

    Args:
        data: 2-D NumPy array of points, one per row.
        k: Number of centers.
        s: Number of leading centers kept fixed by the local search.
        attempt: Number of restarts after the initial run.
        n_iter: Round-budget parameter forwarded to the local search.
        net_distance: Finest offset of the net (``smallest_net_distance``).
        random_seed: Base seed for the local-search runs.

    Returns:
        A tuple ``(cost, centers, assignment)`` of the best run.
    """
    # Only the net itself is needed; the number of scales is not used here.
    net, _ = net_construction(data, smallest_net_distance=net_distance)

    current_cost, current_centers, current_assignment = center_fixed_local_search(
        data, net, k, s, n_iter=n_iter, random_seed=random_seed
    )

    for restart in range(1, attempt + 1):
        # Give every restart its own seed (integer seeds only).
        if isinstance(random_seed, int):
            restart_seed = random_seed + restart
        else:
            restart_seed = random_seed

        new_cost, new_centers, new_assignment = center_fixed_local_search(
            data, net, k, s, n_iter=n_iter, random_seed=restart_seed
        )
        if new_cost < current_cost:
            current_cost = new_cost
            current_centers = new_centers
            current_assignment = new_assignment

    return current_cost, current_centers, current_assignment